# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# NOTE: One must import the PyCuda driver first, before CVCUDA or VPF; otherwise,
# things may throw unexpected errors.
import pycuda.driver as cuda  # noqa: F401

import os
import sys
import json
import time
import logging
import argparse
import subprocess
import numpy as np
import pandas as pd
import multiprocessing as mp
import matplotlib.pyplot as plt

# Make the shared helpers in "<parent of this script's directory>/common/python"
# importable. This has to run before the `perf_utils` import below, which is
# resolved from that directory; inserting at index 0 gives it precedence over
# any other module of the same name on sys.path.
common_dir = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
    "common",
    "python",
)
sys.path.insert(0, common_dir)

from perf_utils import maximize_clocks, reset_clocks  # noqa: E402


class NvtxRangeTimeInfo:
    """
    Start and end timestamps, in milliseconds, of a single NVTX range on one
    timeline (used for both the CPU and the GPU timings of an `NvtxRange`).
    """

    def __init__(self, start_ms, end_ms):
        """
        :param start_ms: Timestamp, in milliseconds, at which the range begins.
        :param end_ms: Timestamp, in milliseconds, at which the range ends.
        """
        self.start_ms, self.end_ms = start_ms, end_ms

    @property
    def duration_ms(self):
        """
        Elapsed time of the range in milliseconds, i.e. ``end_ms - start_ms``.
        """
        elapsed = self.end_ms - self.start_ms
        return elapsed


class NvtxRange:
    """
    One NVTX range: its name, the id of its parent range, and its CPU and GPU
    timing information (each an `NvtxRangeTimeInfo`, or None when not known).
    """

    def __init__(
        self, flat_name, parent_range_id, cpu_time_info=None, gpu_time_info=None
    ):
        """
        :param flat_name: Name of the range as it appears in the NSYS report; may later be
         replaced by a parent-qualified name.
        :param parent_range_id: Range ID of the enclosing (parent) range, or None for a root range.
        :param cpu_time_info: `NvtxRangeTimeInfo` with the CPU-side timings, if any.
        :param gpu_time_info: `NvtxRangeTimeInfo` with the GPU-side timings, if any.
        """
        self.flat_name, self.parent_range_id = flat_name, parent_range_id
        self.cpu_time_info, self.gpu_time_info = cpu_time_info, gpu_time_info


def parse_nvtx_pushpop_trace_json(json_path):
    """
    Parses the nvtx_pushpop_trace JSON generated by NSYS.

    The report is a JSON list of row dictionaries; each row must provide the
    "Name", "Start (ns)", "End (ns)", "RangeId", "ParentId", "PID" and "TID"
    fields. Only CPU timings are taken from this report.

    :param json_path: Full path to the nvtx_pushpop_trace.json file.
    :returns: Nested dictionary ``{process_id: {thread_id: {range_id: NvtxRange}}}``
     where every `NvtxRange` carries only ``cpu_time_info`` (milliseconds, rounded
     to 4 decimals). A "ParentId" equal to the string "None" is stored as None.
     A zero-byte file yields an empty dictionary.
    """
    range_info = {}

    # A zero-byte report means no ops were recorded at all.
    if os.stat(json_path).st_size == 0:
        return range_info

    with open(json_path, "r") as f:
        rows = json.load(f)

    ns_per_ms = 10**6
    for row in rows:
        name = row["Name"]
        start_ns = float(row["Start (ns)"])
        end_ns = float(row["End (ns)"])
        range_id = row["RangeId"]
        parent_id = row["ParentId"]
        pid = row["PID"]
        tid = row["TID"]

        # The report may spell a missing parent as the string "None".
        if parent_id == "None":
            parent_id = None

        time_info = NvtxRangeTimeInfo(
            round(start_ns / ns_per_ms, 4), round(end_ns / ns_per_ms, 4)
        )

        thread_ranges = range_info.setdefault(pid, {}).setdefault(tid, {})
        thread_ranges[range_id] = NvtxRange(name, parent_id, time_info)

    return range_info


def parse_nvtx_gpu_proj_trace_json(json_path):
    """
    Parses the nvtx_gpu_proj_trace JSON generated by NSYS and returns a dictionary
    keyed by process_id, thread_id and range_id.

    The report is a JSON list of row dictionaries; each used row must provide
    "RangeId", "Name", "Orig Start (ns)", "Orig Duration (ns)",
    "Projected Start (ns)", "Projected Duration (ns)", "ParentId", "PID" and "TID".
    Rows whose "RangeId" is null/falsy or the string "None" are skipped.

    NOTE: This report only returns ranges which had GPU work (gpu duration > 0).
          Pure CPU ranges are absent, which is why the nvtx_pushpop_trace report
          has to be parsed as well.

    :param json_path: Full path to the nvtx_gpu_proj_trace.json file.
    :returns: ``{process_id: {thread_id: {range_id: NvtxRange}}}`` where every
     `NvtxRange` has ``cpu_time_info`` (from the "Orig" columns) and
     ``gpu_time_info`` (from the "Projected" columns), in milliseconds rounded to
     4 decimals. A "ParentId" equal to the string "None" is stored as None, the
     same way `parse_nvtx_pushpop_trace_json` stores it. A zero-byte file yields
     an empty dictionary.
    """
    range_info = {}

    # A zero-byte report means no GPU ops were recorded.
    if os.stat(json_path).st_size == 0:
        return range_info

    with open(json_path, "r") as f:
        json_data = json.load(f)

    ns_per_ms = 10**6
    for row in json_data:
        range_id = row["RangeId"]

        # Rows without a usable range id cannot be matched to the pushpop report.
        if not range_id or range_id == "None":
            continue

        flat_name = row["Name"]

        # The report gives start + duration; convert both timelines to start/end.
        cpu_start_ns = float(row["Orig Start (ns)"])
        cpu_end_ns = cpu_start_ns + float(row["Orig Duration (ns)"])

        gpu_start_ns = float(row["Projected Start (ns)"])
        gpu_end_ns = gpu_start_ns + float(row["Projected Duration (ns)"])

        parent_range_id = row["ParentId"]
        process_id = row["PID"]
        thread_id = row["TID"]

        # Keep the parent-id representation consistent with the pushpop parser.
        parent_range_id = None if parent_range_id == "None" else parent_range_id

        cpu_time_info = NvtxRangeTimeInfo(
            round(cpu_start_ns / ns_per_ms, 4), round(cpu_end_ns / ns_per_ms, 4)
        )
        gpu_time_info = NvtxRangeTimeInfo(
            round(gpu_start_ns / ns_per_ms, 4), round(gpu_end_ns / ns_per_ms, 4)
        )

        # We will save it using the Nvtx objects, at the process/thread level.
        thread_ranges = range_info.setdefault(process_id, {}).setdefault(
            thread_id, {}
        )
        thread_ranges[range_id] = NvtxRange(
            flat_name, parent_range_id, cpu_time_info, gpu_time_info
        )

    return range_info


def _resolve_nvtx_qualified_name(ranges, range_id, resolved):
    """
    Returns the parent-qualified name of ``range_id`` within one thread's ``ranges``
    dictionary, memoizing every name it computes into ``resolved``.

    Names are built from the original (leaf) ``flat_name`` of each range, so the
    result does not depend on the order in which ranges are visited.

    :raises KeyError: If a parent range id is not present in ``ranges``.
    :raises ValueError: If the parent chain is cyclic.
    """
    # Walk up the parent chain until a root or an already-resolved ancestor.
    pending = []
    prefix = None
    current_id = range_id
    while True:
        if current_id in resolved:
            prefix = resolved[current_id]
            break
        pending.append(current_id)
        parent_id = ranges[current_id].parent_range_id
        if not parent_id or parent_id == "None":
            break  # Reached a root range.
        if len(pending) > len(ranges):
            raise ValueError(
                "Cyclic NVTX parent chain detected for range %r" % (range_id,)
            )
        current_id = parent_id

    # Build the names top-down, from the highest unresolved ancestor to range_id.
    for pending_id in reversed(pending):
        leaf_name = ranges[pending_id].flat_name
        prefix = leaf_name if prefix is None else os.path.join(prefix, leaf_name)
        resolved[pending_id] = prefix

    return resolved[range_id]


def expand_nvtx_range_names(range_info):
    """
    Converts a hierarchical NVTX range tree with parent-child relationship into a flat
    tree by adding the names of parent nodes in-front of all the child nodes.
    Hence, a tree like the following:
        root
            child_a
                   sub_child_a
            child_b
            child_c
                   sub_child_c

    becomes (names are joined with `os.path.join`, i.e. with "/" on POSIX):
        root
        root/child_a
        root/child_a/sub_child_a
        root/child_b
        root/child_c
        root/child_c/sub_child_c

    The result does not depend on whether a parent appears before or after its
    children in ``range_info``. A range is treated as a root when its
    ``parent_range_id`` is falsy or the string "None".

    NOTE: The ``flat_name`` of every `NvtxRange` in ``range_info`` is overwritten
          in place with its fully qualified name, and the same objects are used
          as values of the returned dictionary.

    :param range_info: The range_info dictionary returned by the parsing functions,
     i.e. ``{process_id: {thread_id: {range_id: NvtxRange}}}``.
    :returns: ``{process_id: {thread_id: {qualified_name: NvtxRange}}}``.
    :raises KeyError: If a range refers to a parent id missing from its thread.
    :raises ValueError: If a thread's parent chain is cyclic.
    """
    final_dict = {}

    for process_id, threads in range_info.items():
        final_threads = final_dict.setdefault(process_id, {})

        for thread_id, ranges in threads.items():
            final_ranges = final_threads.setdefault(thread_id, {})

            # Resolve every qualified name from the original leaf names before
            # mutating anything, so the visit order cannot affect the result.
            resolved = {}
            for range_id in ranges:
                _resolve_nvtx_qualified_name(ranges, range_id, resolved)

            for range_id, nvtx_range in ranges.items():
                nvtx_range.flat_name = resolved[range_id]
                final_ranges[nvtx_range.flat_name] = nvtx_range

    return final_dict


def merge_cpu_and_gpu_ranges(cpu_range_info, gpu_range_info):
    """
    Attaches GPU timing information to every CPU NVTX range.

    ``cpu_range_info`` is treated as the authoritative set of ranges, because the
    GPU projection report omits ranges which were purely CPU code. For every range
    in it, the range with the same process id, thread id and range name is looked
    up in ``gpu_range_info``. If found, its ``gpu_time_info`` is copied over;
    otherwise a zero-length ``NvtxRangeTimeInfo(0, 0)`` is used.

    NOTE: The `NvtxRange` objects of ``cpu_range_info`` are updated in place and
          reused (not copied) as the values of the returned dictionary.

    :param cpu_range_info: ``{process_id: {thread_id: {range_name: NvtxRange}}}``
     holding the CPU ranges.
    :param gpu_range_info: Same layout, holding the GPU-projected ranges.
    :returns: ``{process_id: {thread_id: {range_name: NvtxRange}}}`` with the same
     keys as ``cpu_range_info``, where every range has ``gpu_time_info`` set.
    """
    all_ranges_info = {}

    for process_id, cpu_threads in cpu_range_info.items():
        merged_threads = all_ranges_info.setdefault(process_id, {})
        gpu_threads = gpu_range_info.get(process_id, {})

        for thread_id, cpu_ranges in cpu_threads.items():
            merged_ranges = merged_threads.setdefault(thread_id, {})
            gpu_ranges = gpu_threads.get(thread_id, {})

            for range_name, nvtx_range in cpu_ranges.items():
                if range_name in gpu_ranges:
                    gpu_time_info = gpu_ranges[range_name].gpu_time_info
                else:
                    # Pure CPU range: no GPU work was projected onto it.
                    gpu_time_info = NvtxRangeTimeInfo(0, 0)

                nvtx_range.gpu_time_info = gpu_time_info
                merged_ranges[range_name] = nvtx_range

    return all_ranges_info


def calc_mean_ranges(all_range_info):
    """
    Calculates the mean of all NVTX ranges present in the all_range_info. Since NVTX ranges
    can be reported per process, per thread, we need to have a way to average those numbers.
    The mean here is the plain (unweighted) average of a range's durations over every
    process/thread in which that range name appears.

    :param all_range_info: ``{process_id: {thread_id: {range_name: NvtxRange}}}`` whose
     values provide ``cpu_time_info.duration_ms`` and ``gpu_time_info.duration_ms``
     (e.g. the output of `merge_cpu_and_gpu_ranges`).
    :returns: ``{range_name: (mean_cpu_ms, mean_gpu_ms)}``, each mean rounded to 4
     decimals, in the order the range names were first encountered.
    """
    # Gather every CPU and GPU duration per range name.
    samples = {}
    for threads in all_range_info.values():
        for ranges in threads.values():
            for range_name, nvtx_range in ranges.items():
                cpu_durations, gpu_durations = samples.setdefault(
                    range_name, ([], [])
                )
                cpu_durations.append(nvtx_range.cpu_time_info.duration_ms)
                gpu_durations.append(nvtx_range.gpu_time_info.duration_ms)

    # Each entry is created only when a range is seen, so it always holds at
    # least one sample and the divisions below are safe.
    return {
        range_name: (
            round(sum(cpu_durations) / len(cpu_durations), 4),
            round(sum(gpu_durations) / len(gpu_durations), 4),
        )
        for range_name, (cpu_durations, gpu_durations) in samples.items()
    }


class NumpyValuesEncoder(json.JSONEncoder):
    """
    JSON encoder that converts NumPy scalars and arrays into native Python values
    so results containing them can be serialized with ``json``.
    """

    def default(self, obj):
        """
        Returns a JSON-serializable equivalent of NumPy integer, floating, boolean
        and ndarray values; defers to ``json.JSONEncoder.default`` (which raises
        ``TypeError``) for anything else.
        """
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            # np.bool_ is neither a Python bool nor an np.integer, so json cannot
            # handle it natively and the checks above do not catch it.
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)


def recurse_gather_dict(input_dict, target_dict):
    """
    Recursively collects the values of ``input_dict`` into ``target_dict`` so that
    statistics (mean, median, std-dev, ...) can later be computed per key.

    Calling this once per sample dictionary builds, in ``target_dict``, a structure
    that mirrors the nesting of the inputs:
      * nested dictionaries are mirrored and recursed into;
      * list/tuple values become a list of lists, where bucket ``i`` collects
        element ``i`` of every gathered sample;
      * any other value is appended to a flat list.
    The bookkeeping keys "total_items", "total_items_warmup" and
    "total_items_minus_warmup" are never gathered.

    :param input_dict: The dictionary that should be used as input.
    :param target_dict: The single dictionary in which all values are accumulated.
    """
    assert type(input_dict) is type(target_dict)
    assert isinstance(input_dict, dict)

    skipped_keys = ("total_items", "total_items_warmup", "total_items_minus_warmup")

    for key, value in input_dict.items():
        if key in skipped_keys:
            continue

        if isinstance(value, dict):
            recurse_gather_dict(value, target_dict.setdefault(key, {}))
        elif isinstance(value, (list, tuple)):
            # One bucket per position so element i of every sample lands together.
            buckets = target_dict.setdefault(key, [[] for _ in value])
            for position, element in enumerate(value):
                buckets[position].append(element)
        else:
            target_dict.setdefault(key, []).append(value)


def _latency_ms_to_throughput(latency_ms, throughput_multiplier):
    """
    Converts a latency in milliseconds into a per-second throughput, scaled by
    ``throughput_multiplier`` and rounded to 2 decimals. Non-positive latencies
    map to 0 so that zero timings never produce a division by zero or infinity.
    """
    if latency_ms > 0:
        return round(1000 * throughput_multiplier / latency_ms, 2)
    return 0


def recurse_calc_stats_dict(
    input_dict,
    compute_mean_only=False,
    compute_throughput=False,
    throughput_multiplier=1,
):
    """
    Recursively calculates various stats on the value of all keys of input_dict.
    Every list value is replaced in place by a dictionary of its statistics;
    nested dictionaries are recursed into.

    :param input_dict: The dictionary that should be used as input. Every leaf value
     must be a list (as built by `recurse_gather_dict`).
    :param compute_mean_only: A flag indicating whether only the mean should be computed or not.
     Computes a lot of other stats (e.g. median, min, max...) if set to False.
    :param compute_throughput: A flag indicating whether throughput should be computed or not.
     Only set to True when running in parallel with all resources maximized otherwise throughput
     calculation may give incorrect results. Ignored when compute_mean_only is True.
    :param throughput_multiplier: A number with which the throughput is multiplied to calculate the
     total throughput. Usually set to the number of parallel processes or threads executing in parallel.
    """
    # Iterate over a snapshot of the keys because values are replaced in place.
    for key in list(input_dict.keys()):
        values = input_dict[key]

        if isinstance(values, dict):
            recurse_calc_stats_dict(
                values,
                compute_mean_only,
                compute_throughput,
                throughput_multiplier,
            )
            continue

        assert isinstance(values, list)

        if compute_mean_only:
            stats_dict = {
                "total_items": len(values),
                "mean": round(np.mean(values, axis=-1), 4),
            }
        else:
            stats_dict = {
                "total_items": len(values),
                "min": round(np.min(values, axis=-1), 4),
                "max": round(np.max(values, axis=-1), 4),
                "mean": round(np.mean(values, axis=-1), 4),
                "std": round(np.std(values, axis=-1), 4),
                "median": round(np.median(values, axis=-1), 4),
                "percentile_16": round(np.percentile(values, 16, axis=-1), 4),
                "percentile_84": round(np.percentile(values, 84, axis=-1), 4),
                "percentile_95": round(np.percentile(values, 95, axis=-1), 4),
            }

            if compute_throughput:
                throughput_unit = (
                    "frames_per_second" if "_per_item" in key else "batches_per_second"
                )
                stats_dict["throughput"] = {
                    "multiplier": throughput_multiplier,
                    "unit": throughput_unit,
                    # NOTE: Minimum throughput corresponds to maximum latency.
                    "min": _latency_ms_to_throughput(
                        stats_dict["max"], throughput_multiplier
                    ),
                    # NOTE: Maximum throughput corresponds to minimum latency.
                    "max": _latency_ms_to_throughput(
                        stats_dict["min"], throughput_multiplier
                    ),
                    "mean": _latency_ms_to_throughput(
                        stats_dict["mean"], throughput_multiplier
                    ),
                    "median": _latency_ms_to_throughput(
                        stats_dict["median"], throughput_multiplier
                    ),
                    # Each bound is guarded on its own latency: a zero 16th
                    # percentile must not turn the upper bound into infinity.
                    "percentile_68_range": [
                        _latency_ms_to_throughput(
                            stats_dict["percentile_84"], throughput_multiplier
                        ),
                        _latency_ms_to_throughput(
                            stats_dict["percentile_16"], throughput_multiplier
                        ),
                    ],
                    "percentile_95": _latency_ms_to_throughput(
                        stats_dict["percentile_95"], throughput_multiplier
                    ),
                }

        # Assign in-place.
        input_dict[key] = stats_dict


def unflatten_process_benchmark_dict(benchmark_dict, warmup_batches):
    """
    Un-flattens (i.e expands) the data present in benchmark_dict and also calculates
    additions numbers.
    """
    # This function needs to do a few different things. Here is the overall flow:
    #
    # 1. It has to expand the keys
    #       so 'run_sample/pipeline/batch_0/preprocess.cvcuda' from NSYS json
    #       becomes the following nested dictionary:
    #       run_sample : {
    #           pipeline : {
    #               batch_0 : {
    #                   preprocess.cvcuda : {"cpu_time": 0, "gpu_time": 0}
    #               }
    #           }
    #       }
    #
    # 2. Then it has to compute total of CPU and GPU times by aggregating those
    #    numbers at each level. In doing so, it has to account for warm-up batches
    #    i.e. batches whose timings should not be counted towards the total.
    #       run_sample : {
    #           pipeline : {
    #               batch_0 : {
    #                   preprocess.cvcuda  : {cpu_time: 0, gpu_time: 0}
    #                   postprocess.cvcuda : {cpu_time: 0, gpu_time: 0}
    #                   cpu_time : 0.0
    #                   gpu_time : 0.0
    #               }
    #           }
    #       }
    #
    #
    # 3. It also has to compute those times per frame/item. For this to happen, it needs
    #    the batch size information (i.e. how many items/frames were inside a batch) and
    #    also the information on which keys were "inside" a batch and which were not.
    #    These two pieces of information is taken from the benchmark.json
    #         pipeline : {
    #               batch_0 : {
    #                   preprocess.cvcuda : {cpu_time: 0, gpu_time: 0}
    #                   postprocess.cvcuda : {cpu_time: 0, gpu_time: 0}
    #                   cpu_time : 0.0
    #                   gpu_time : 0.0
    #                   cpu_time_per_item: 0.0
    #                   gpu_time_per_item: 0.0
    #               }
    #           }
    #       }
    #
    # 4. Finally, it computes various stats (e.g mean, median) of the timings from all
    #    the batches. In other words, it computes how much range X would take on an
    #    average when it is averaged across all the batches. To do this, we again use
    #    the information present inside benchmark.json and apply basic recursion math.
    #

    unfltten_data_dict = {}  # This is where we will store un-flattened data for now.

    # Maintains the total time of all warm-up batches.
    # this is keyed by the batch level prefix and values will be the time.
    total_warmup_cpu_time = {}
    total_warmup_gpu_time = {}
    # Maintain a count of total number of frames processed with counting the warm-up.
    total_items = {}
    # Maintain a count of total number of frames processed without counting the warm-up.
    total_items_minus_warmup = {}
    # Maintain pointers to the batch level sub-dictionaries. The keys will still be
    # the batch level prefix and values will be the dictionary.
    batch_dicts = {}

    # Loop over all the paths stored as keys in the input dictionary.
    for path in benchmark_dict["data"]:
        # Split the path expression by /
        parts = path.split("/")
        # Maintain a pointer to the dictionary current being traversed, initially set
        # to the empty results dictionary.
        current_dict = unfltten_data_dict
        # Loop over all but the last part, last part will be set as a value.
        for p in parts[:-1]:
            # Add the key if not already added, with a blank dict as its value.
            if p not in current_dict:
                current_dict[p] = {}

            # Update the dict pointer with nested expression.
            current_dict = current_dict[p]

        # Once all the sub-dictionaries are created, we need to assign them the correct
        # value. We need to be careful here because our expressions may have
        # nested keys such as:
        # batch_0:
        #       pre_process:
        #       post_process:
        # In the example above, we will have 3 numbers (2 for stages and 1 for overall batch)

        cpu_time = benchmark_dict["data"][path][0]
        gpu_time = benchmark_dict["data"][path][1]

        if parts[-1] not in current_dict:
            current_dict[parts[-1]] = {}

        current_dict[parts[-1]]["cpu_time"] = cpu_time
        current_dict[parts[-1]]["gpu_time"] = gpu_time

        # Now we will check at which exact level this path sits.
        # There are 3 possibilities:
        # 1. Exactly at the batch level.
        # 2. Inside the batch.
        # 3. Outside of/above the batch.
        #
        # Based on its placement, it would receive different treatments.

        # Check if this was at the batch level.
        if path in benchmark_dict["batch_info"]:
            # We are exactly at a batch level.
            batch_idx, batch_size = benchmark_dict["batch_info"][path]
            # Also find out the batch level prefix. i.e. one level above.
            batch_level_prefix = os.path.dirname(path)
            batch_dicts[batch_level_prefix] = current_dict

            # Add total items
            current_dict[parts[-1]]["total_items"] = batch_size

            # Computer per item.
            if batch_size > 0:
                current_dict[parts[-1]]["cpu_time_per_item"] = round(
                    current_dict[parts[-1]]["cpu_time"] / batch_size, 4
                )

                current_dict[parts[-1]]["gpu_time_per_item"] = round(
                    current_dict[parts[-1]]["gpu_time"] / batch_size, 4
                )

            # Pass the total_items information to all the children of this batch
            # i.e. keys which were present in inside_batch_info
            # unless they already had it before (i.e. very weird case where someone
            # inserted a batch in a batch with different inner batch size).
            def _recurse_update_children_total_items(in_dict):
                for k in list(in_dict.keys()):
                    if isinstance(in_dict[k], dict):
                        _recurse_update_children_total_items(in_dict[k])
                    else:
                        # Add total items if not already present.
                        if "total_items" not in in_dict:
                            in_dict["total_items"] = batch_size
                            in_dict["cpu_time_per_item"] = round(
                                in_dict["cpu_time"] / batch_size, 4
                            )
                            in_dict["gpu_time_per_item"] = round(
                                in_dict["gpu_time"] / batch_size, 4
                            )

            # Apply it.
            if batch_size > 0:
                _recurse_update_children_total_items(current_dict[parts[-1]])

            # Maintain global counts of various batch level stats
            # for example, counting the total items seen at this
            # batch level.
            if batch_level_prefix not in total_items:
                total_items[batch_level_prefix] = 0
                total_items_minus_warmup[batch_level_prefix] = 0
                total_warmup_cpu_time[batch_level_prefix] = 0
                total_warmup_gpu_time[batch_level_prefix] = 0

            # Add to the totals at this batch level.
            total_items[batch_level_prefix] += batch_size

            # Check if this batch was not in the warm-up period.
            # Batches from the front and end are ignored that fall under the warm-up period.
            if batch_size > 0:
                if (
                    batch_idx + 1 > warmup_batches
                    and (batch_idx + 1 + warmup_batches)
                    <= benchmark_dict["meta"]["total_batches"][batch_level_prefix]
                ):
                    # This is a non-warmup batch.
                    total_items_minus_warmup[batch_level_prefix] += batch_size

                else:
                    # This is a warm-up batch. Add its timings so that we can
                    # subtract it later from the totals.
                    total_warmup_cpu_time[batch_level_prefix] += current_dict[
                        parts[-1]
                    ]["cpu_time"]
                    total_warmup_gpu_time[batch_level_prefix] += current_dict[
                        parts[-1]
                    ]["gpu_time"]

        elif path in benchmark_dict["inside_batch_info"]:
            # We could be inside a batch. Nothing to do here. We won't have the total_items yet
            # at this level. We will update it when we come on a parent level.
            pass
        else:
            # We are one or more levels outside/above the batch level.
            # We will need to correctly pass the total items stats here.
            current_dict[parts[-1]]["cpu_time_per_item"] = 0
            current_dict[parts[-1]]["gpu_time_per_item"] = 0

            # For cases where we are one level above the batch level, we can directly use the
            # stats stored in our dictionaries. Path's value here will be equal to batch_level_prefix

            if path in total_items:
                # We are exactly one level outside/above the batch level.
                if total_items[path] > 0:
                    current_dict[parts[-1]]["cpu_time_per_item"] = round(
                        current_dict[parts[-1]]["cpu_time"] / total_items[path],
                        4,
                    )
                    current_dict[parts[-1]]["gpu_time_per_item"] = round(
                        current_dict[parts[-1]]["gpu_time"] / total_items[path],
                        4,
                    )
                current_dict[parts[-1]]["total_items"] = total_items[path]

            else:
                # We are more than one level outside/above the batch level.
                # We will need to report stats summing up all the nested batch levels.
                total_items_above_level = 0

                for k in total_items:
                    if k.startswith(path):
                        total_items_above_level += total_items[k]

                if total_items_above_level > 0:
                    current_dict[parts[-1]]["cpu_time_per_item"] = round(
                        current_dict[parts[-1]]["cpu_time"] / total_items_above_level,
                        4,
                    )
                    current_dict[parts[-1]]["gpu_time_per_item"] = round(
                        current_dict[parts[-1]]["gpu_time"] / total_items_above_level,
                        4,
                    )
                current_dict[parts[-1]]["total_items"] = total_items_above_level

    # Add warm-up related keys exactly at the batch level dictionaries.
    # Warm-up time is the time taken by the warm-up number of batches
    # at the beginning and at the end of the pipeline.
    # So for any warm-up batches > 0, we add up their run times and
    # subtract those from the total run time at the end.
    for batch_level_prefix in batch_dicts:
        batch_dict = batch_dicts[batch_level_prefix]

        batch_dict["total_items_warmup"] = (
            batch_dict["total_items"] - total_items_minus_warmup[batch_level_prefix]
        )

        batch_dict["total_items_minus_warmup"] = total_items_minus_warmup[
            batch_level_prefix
        ]

        batch_dict["cpu_time_minus_warmup"] = round(
            (batch_dict["cpu_time"] - total_warmup_cpu_time[batch_level_prefix]), 4
        )
        batch_dict["gpu_time_minus_warmup"] = round(
            (batch_dict["gpu_time"] - total_warmup_gpu_time[batch_level_prefix]), 4
        )

        batch_dict["cpu_time_minus_warmup_per_item"] = 0
        batch_dict["gpu_time_minus_warmup_per_item"] = 0

        if total_items_minus_warmup[batch_level_prefix] > 0:
            batch_dict["cpu_time_minus_warmup_per_item"] = round(
                batch_dict["cpu_time_minus_warmup"]
                / total_items_minus_warmup[batch_level_prefix],
                4,
            )
            batch_dict["gpu_time_minus_warmup_per_item"] = round(
                batch_dict["gpu_time_minus_warmup"]
                / total_items_minus_warmup[batch_level_prefix],
                4,
            )

    # The processing is over. So we assign the expanded version of data into the
    # original benchmark dictionary.
    benchmark_dict["data"] = unfltten_data_dict

    # Finally, process the batches to calculate various stats on the batch timings.
    # i.e. how much did range X took on an average across all the batches.
    # Again, we will not use any batches that are warm-up batches in this calculation.
    # For this to happen, we need to rely on the batch_info keys. Those are the
    # markers telling us what constitutes as "inside a batch". Then we build a union
    # of all the keys inside the batch, sum it up and find out the average.
    #
    # NOTE: Although not required, it would be good idea that to name the ranges inside
    #       various batches the same name.
    #       i.e. decode of batch_0 and batch_1 both be called 'decode'.
    #       It won't be an issue if that is not the case. Just that the average value
    #       will be averaged on non-uniform number of samples.
    #       e.g. if there is a range that is only used during last 3 batches, its mean
    #            value will be a mean over 3 samples compared to a range which is used
    #            during all the batches. Our division logic takes care of properly
    #            dividing with the current count anyway.
    #
    data_stats = {}
    for batch_range_name in benchmark_dict["batch_info"]:
        batch_idx, batch_size = benchmark_dict["batch_info"][batch_range_name]
        # Next, we find out the batch level prefix. This is the key in which
        # the batches are nested. One profiling session can have multiple levels
        # at which batches may be used.
        # e.g.
        # program_X:
        #   method_A:
        #       batch_1
        #       batch_2
        #   method_B:
        #       batch_1
        #       batch_2
        #
        # We need to find mean at these two levels (i.e. method_A and method_B)
        # in this case.
        # programA/method_A and program_A/method_B are the batch level prefix here.
        # We can easily get those by using the dirname method since those are like
        # the directory names in a path.
        batch_level_prefix = os.path.dirname(batch_range_name)

        if (
            batch_size > 0
            and batch_idx + 1 > warmup_batches
            and (batch_idx + 1 + warmup_batches)
            <= benchmark_dict["meta"]["total_batches"][batch_level_prefix]
        ):
            # Keep on updating the data_stats dictionary. This will
            # create a dictionary that is union of all the dictionaries of the batch level.
            nested_keys = batch_range_name.split("/")
            source_dict = benchmark_dict["data"]
            for k in nested_keys:
                source_dict = source_dict[k]

            # Need to recursively update the data_stats based on
            # the source_dict. We will sum the values up.
            if batch_level_prefix not in data_stats:
                data_stats[batch_level_prefix] = {}

            recurse_gather_dict(source_dict, target_dict=data_stats[batch_level_prefix])

    # Once all the numbers are gathered, we need to divide by the length to figure
    # out the mean values.
    recurse_calc_stats_dict(data_stats)
    benchmark_dict["data_stats_minus_warmup"] = data_stats

    # Remove the batch_info and inside_batch_info keys as they are no longer needed.
    del benchmark_dict["inside_batch_info"]


def benchmark_script(
    process_idx,
    device_id,
    output_dir,
    warmup_batches,
    script,
    args,
):
    """
    Runs a python script under NSYS and turns its profile into benchmark numbers.

    The script is launched through ``nsys profile`` restricted to a single GPU. Its
    NVTX ranges are then exported with ``nsys stats`` and merged into the
    benchmark.json file that the script itself wrote (through perf_utils) in
    ``output_dir``.

    :param process_idx: The 0-based index of this process.
    :param device_id: The GPU device id to use. It is exported as CUDA_VISIBLE_DEVICES.
    :param output_dir: The output directory to use to store artifacts.
    :param warmup_batches: The number of batches that should be ignored from benchmarking.
    :param script: The python script to execute.
    :param args: Any optional command line arguments that should be passed to the script.
    :returns: A ``(return_code, output_dir)`` tuple. A non-zero return code means one
        of the steps failed.
    :raises ValueError: If the nsys binary cannot be found.
    """
    # The child sees only its own GPU, and BENCHMARK_PY tells perf_utils that this
    # is a benchmark run.
    child_env = dict(os.environ)
    child_env["CUDA_VISIBLE_DEVICES"] = str(device_id)
    child_env["BENCHMARK_PY"] = "1"

    # Base path (without extension) of the NSYS report, and the JSON written by the script.
    report_base = os.path.join(output_dir, "perf_report")
    json_path = os.path.join(output_dir, "benchmark.json")

    # Start from a clean slate: a stale benchmark.json must not be mistaken for a new one.
    if os.path.isfile(json_path):
        os.remove(json_path)

    nsys_root = "/opt/nvidia/nsight-systems/2024.2.1/"
    nsys_bin = os.path.join(nsys_root, "bin/nsys")
    reports_dir = os.path.join(nsys_root, "target-linux-x64/reports")
    gpu_proj_report = os.path.join(reports_dir, "nvtx_gpu_proj_trace")
    pushpop_report = os.path.join(reports_dir, "nvtx_pushpop_trace")

    if not os.path.isfile(nsys_bin):
        raise ValueError(
            "Unable to locate nsys binary at %s. Make sure you have nsight-systems 2024.2.1 installed."
            % nsys_bin
        )

    # Step 1: profile the script and export the raw trace as an SQLite database.
    profile_cmd = [
        nsys_bin,
        "profile",
        "--export",
        "sqlite",
        "-o",
        report_base,
        "--force-overwrite",
        "true",
        "--trace",
        "cuda,nvtx",
        "--trace-fork-before-exec=true",
        "--gpu-video-device",
        "all",
        sys.executable,
        script,
        *args,
    ]
    return_code = subprocess.Popen(
        profile_cmd, stdout=None, stderr=None, env=child_env
    ).wait()
    if return_code:
        return return_code, output_dir

    # A script running in benchmark mode through perf_utils must leave a benchmark.json
    # behind. It may not be in its final form yet (perf values can still be zero), but
    # it has to exist.
    if not os.path.isfile(json_path):
        logging.error(
            "benchmark.json was not found for process: %d at: %s. "
            "Did the script forget to call CvCudaPerf.finalize()?"
            % (process_idx, json_path)
        )
        return 1, output_dir

    with open(json_path, "r") as f:
        benchmark_dict = json.load(f)

    # Step 2: convert the SQLite database into two JSON reports (CPU push/pop and GPU projection).
    stats_cmd = [
        nsys_bin,
        "stats",
        "--force-overwrite",
        "true",
        "-r",
        "%s,%s" % (gpu_proj_report, pushpop_report),
        "-f",
        "json",
        "-o",
        report_base,
        report_base + ".sqlite",
    ]
    return_code = subprocess.Popen(
        stats_cmd, stdout=None, stderr=None, env=child_env
    ).wait()
    if return_code:
        return return_code, output_dir

    # Step 3: parse both reports, expand the range names, merge CPU with GPU timings
    # and average them across processes/threads.
    pushpop_json = os.path.join(output_dir, "perf_report_nvtx_pushpop_trace.json")
    gpu_proj_json = os.path.join(output_dir, "perf_report_nvtx_gpu_proj_trace.json")
    cpu_ranges = parse_nvtx_pushpop_trace_json(pushpop_json)
    gpu_ranges = parse_nvtx_gpu_proj_trace_json(gpu_proj_json)
    cpu_ranges = expand_nvtx_range_names(cpu_ranges)
    gpu_ranges = expand_nvtx_range_names(gpu_ranges)
    merged_ranges = merge_cpu_and_gpu_ranges(cpu_ranges, gpu_ranges)
    mean_ranges_info = calc_mean_ranges(merged_ranges)

    # Step 4: copy the timings into benchmark_dict. NSYS does not know about the root
    # obj_name key, so it is prepended before looking a range up.
    for range_name, range_info in mean_ranges_info.items():
        flat_key = os.path.join(benchmark_dict["meta"]["obj_name"], range_name)
        if flat_key in benchmark_dict["data"]:
            benchmark_dict["data"][flat_key] = range_info

    # Un-flatten and derive per-batch / per-item stats from the batch info that
    # perf_utils saved. NSYS does not compute these.
    unflatten_process_benchmark_dict(benchmark_dict, warmup_batches)

    with open(json_path, "w") as f:
        f.write(json.dumps(benchmark_dict, indent=4, cls=NumpyValuesEncoder))

    # Clean up the intermediate NSYS artifacts.
    for leftover in (
        pushpop_json,
        gpu_proj_json,
        os.path.join(output_dir, "perf_report.sqlite"),
    ):
        os.remove(leftover)

    return 0, output_dir


def monitor_gpu_metrics(
    list_of_device_ids, terminate_event, gpu_metrics_info, poll_interval_s=0.3
):
    """
    Monitors various GPU metrics directly from nvidia-smi. NSYS does not expose these
    metrics as of now. Monitoring continues until ``terminate_event`` is set.

    On every poll, one sample per GPU and per metric is recorded:

    1) The total GPU power drawn in Watts
    2) The GPU utilization in %.
    3) The GPU temperature
    4) The clock event reasons active (reasons why the GPU clock was changed).
    5) The current graphics clock.

    0.0 is recorded whenever a sample cannot be obtained. This covers nvidia-smi being
    missing or failing, a missing output line, or a non-numeric field such as ``[N/A]``
    or ``[Not Supported]``. The monitor therefore never dies mid-run, and all series
    keep the same length.

    :param list_of_device_ids: A list of string values of the GPU-ids that are being used in
        this benchmark run.
    :param terminate_event: An event (anything with ``is_set()``) that marks the end of
        the monitoring.
    :param gpu_metrics_info: A dict (e.g. a multiprocessing managed dict) that receives the
        results when monitoring stops: metric name -> {"GPU: <id>": [samples, ...]}.
    :param poll_interval_s: Seconds to sleep between two nvidia-smi polls.
    """
    # Metric names, in the same order as the fields requested with --query-gpu below.
    metric_names = (
        "power.draw.watts",
        "utilization.gpu",
        "temperature.gpu",
        "clocks_event_reasons",
        "clocks.current.graphics",
    )
    # clocks_event_reasons.active is reported as a hexadecimal bit mask.
    hex_metrics = {"clocks_event_reasons"}

    def _parse_sample(metric_name, text):
        """Convert one nvidia-smi field to a number, falling back to 0.0."""
        try:
            if metric_name in hex_metrics:
                return int(text, 16)  # Hex to Decimal
            return float(text)
        except ValueError:
            return 0.0

    # Work in a local dictionary and only transfer it at the end. A multiprocessing
    # managed dict cannot detect changes to nested values, so in-place appends would
    # be lost.
    gpu_metrics_info_local = {
        metric_name: {"GPU: %s" % device_id: [] for device_id in list_of_device_ids}
        for metric_name in metric_names
    }

    query_cmd = [
        "nvidia-smi",
        "-i=%s" % ",".join(list_of_device_ids),
        "--query-gpu=power.draw,utilization.gpu,temperature.gpu,"
        "clocks_event_reasons.active,clocks.current.graphics",
        "--format=csv,nounits,noheader",
    ]

    try:
        while not terminate_event.is_set():
            try:
                proc_ret = subprocess.run(query_cmd, stdout=subprocess.PIPE)
                if proc_ret.returncode == 0:
                    outputs = (
                        proc_ret.stdout.decode(errors="replace").strip().split("\n")
                    )
                else:
                    outputs = []
            except OSError:
                # nvidia-smi could not be launched at all (e.g. not installed).
                outputs = []

            for idx, device_id in enumerate(list_of_device_ids):
                fields = outputs[idx].split(",") if idx < len(outputs) else []
                if len(fields) != len(metric_names):
                    # Missing or malformed line: record 0.0 for every metric.
                    fields = [""] * len(metric_names)
                for metric_name, text in zip(metric_names, fields):
                    gpu_metrics_info_local[metric_name]["GPU: %s" % device_id].append(
                        _parse_sample(metric_name, text)
                    )

            time.sleep(poll_interval_s)
    finally:
        # Publish what has been gathered, even if the loop ended with an error.
        gpu_metrics_info.update(gpu_metrics_info_local)


def plot_gpu_metrics(gpu_metrics_info, output_dir):
    """
    Plots every GPU metric as a line chart saved as ``plot.<metric_name>.jpg``.

    A metric with no samples at all is skipped. This happens, for example, when the
    benchmark ends before the first nvidia-smi poll. pandas refuses to plot an empty
    frame, and that error would otherwise hide the real outcome of the benchmark.

    :param gpu_metrics_info: Mapping metric name -> {"GPU: <id>": [samples, ...]}.
    :param output_dir: Folder in which the JPEG files are written.
    """
    for metric_name, samples_per_gpu in gpu_metrics_info.items():
        if not any(len(samples) for samples in samples_per_gpu.values()):
            continue

        # One column per GPU, one row per poll.
        df = pd.DataFrame(samples_per_gpu)
        ax = df.plot(title=metric_name)
        fig = ax.get_figure()
        try:
            ax.set_xlabel("Execution time")
            ax.set_ylabel(metric_name)
            fig.savefig(os.path.join(output_dir, "plot.%s.jpg" % metric_name))
        finally:
            # Release the figure even when saving fails, so figures do not pile up.
            plt.close(fig)


def main():
    """
    Benchmarks a python script with NSYS across one or more processes and GPUs.

    Steps:
      1. Parse and validate the command line.
      2. Optionally maximize the GPU clocks.
      3. Launch ``num_processes`` copies of the script per GPU (each through
         ``benchmark_script``) while a side process samples GPU metrics via nvidia-smi.
      4. Stop the monitor and restore the clocks. This always happens, even when
         launching or running the benchmark fails.
      5. Plot the GPU metrics, store them in every per-process benchmark.json and
         aggregate all processes into ``<output_dir>/benchmark_mean.json``.

    :raises ValueError: On an invalid command line. This covers a missing script,
        non-positive process/GPU counts, or an output directory passed to the
        benchmarked script.
    :raises Exception: If any benchmarked process exits with a non-zero return code.
    """
    parser = argparse.ArgumentParser(
        "Performance benchmarking script for CV-CUDA.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "-np",
        "--num_processes",
        type=int,
        default=1,
        help="The number of processes to spawn.",
    )

    parser.add_argument(
        "-ng", "--num_gpus", type=int, default=1, help="The number of GPUs to use."
    )

    parser.add_argument(
        "-go",
        "--gpu_offset_id",
        type=int,
        default=0,
        help="Offset for the GPU ids, assuming the GPUs are stacked together in a multi-GPU node.",
    )

    parser.add_argument(
        "-ll",
        "--log_level",
        type=str,
        choices=["info", "error", "debug", "warning"],
        default="info",
        help="Sets the desired logging level. Affects the std-out printed by the "
        "sample when it is run.",
    )

    parser.add_argument(
        "-o",
        "--output_dir",
        default="/tmp",
        type=str,
        help="The folder where the output results should be stored.",
    )

    parser.add_argument(
        "-w",
        "--warmup_batches",
        type=int,
        default=1,
        help="Sets the number of batches that should be ignored from being counted in "
        "the totals of the performance benchmarking numbers. These many batches are ignored"
        " from the front and the end.",
    )

    parser.add_argument(
        "-m",
        "--maximize_clocks",
        action="store_true",
        help="Maximizes the GPU clocks and power limits before running the benchmark. "
        "Clocks are not maximized by default.",
    )

    parser.add_argument(
        "script",
        help="The script that you want to benchmark.",
    )

    parser.add_argument(
        "args",
        nargs=argparse.REMAINDER,
        help="Any command-line arguments that should be passed to the script being benchmarked.",
    )

    args = parser.parse_args()

    if not os.path.isfile(args.script):
        raise ValueError("Script file does not exist at: %s" % args.script)

    # At least one process must run. Otherwise there is no benchmark.json to aggregate
    # and the stats section below would have nothing to work on.
    if args.num_processes < 1 or args.num_gpus < 1:
        raise ValueError(
            "Both --num_processes and --num_gpus must be >= 1. Got: %d and %d."
            % (args.num_processes, args.num_gpus)
        )

    logging.basicConfig(
        format="[%(name)s:%(lineno)d] %(asctime)s %(levelname)-6s %(message)s",
        level=getattr(logging, args.log_level.upper()),
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    logger = logging.getLogger("benchmark.py")

    # Raise an error if the output directory is also given to the child script. The
    # "--output_dir=<path>" spelling is checked too.
    if any(
        arg in ("-o", "--output_dir") or arg.startswith("--output_dir=")
        for arg in args.args
    ):
        raise ValueError(
            "The output directory must only be specified once for benchmark.py. "
            "Do not specify it in the command line arguments of the script to be benchmarked."
        )

    # Physical ids of the GPUs in use, as strings. The GPU monitor keys its results
    # with these same strings.
    all_device_ids = [
        str(args.gpu_offset_id + gpu_idx) for gpu_idx in range(args.num_gpus)
    ]
    total_processes = args.num_gpus * args.num_processes

    # An event that tells monitor_gpu_metrics to stop once the benchmark is over.
    pool_terminate_event = mp.Event()
    # A shared dict to retrieve the results of monitor_gpu_metrics.
    mp_manager = mp.Manager()
    gpu_metrics_info = mp_manager.dict()
    # One process keeps monitoring GPU metrics that are not available via NSYS.
    gpu_metric_monitor_proc = mp.Process(
        target=monitor_gpu_metrics,
        args=(all_device_ids, pool_terminate_event, gpu_metrics_info),
    )

    # (was_persistence_mode_on, current_power_limit) per GPU, indexed like gpu_idx.
    clocks_info = []
    # (process_idx, gpu_idx, device_id, AsyncResult) for every launched process.
    results = []
    monitor_started = False

    try:
        # Maximize the clocks.
        if args.maximize_clocks:
            for gpu_idx in range(args.num_gpus):
                device_id = args.gpu_offset_id + gpu_idx
                (
                    _,
                    was_persistence_mode_on,
                    current_power_limit,
                ) = maximize_clocks(logger, device_id)
                clocks_info.append((was_persistence_mode_on, current_power_limit))

        gpu_metric_monitor_proc.start()
        monitor_started = True

        # The pool is sized so that every benchmark process runs concurrently. The
        # throughput computed below is multiplied by num_gpus * num_processes. That
        # is only meaningful if all processes ran at the same time, and a default-sized
        # Pool (os.cpu_count() workers) would silently serialize some of them.
        with mp.Pool(processes=total_processes) as pool:
            for gpu_idx in range(args.num_gpus):
                proc_device_id = all_device_ids[gpu_idx]
                for process_idx in range(args.num_processes):
                    # Each process stores its output in its own directory.
                    proc_output_dir = os.path.join(
                        args.output_dir, "proc_%d_gpu_%d" % (process_idx, gpu_idx)
                    )
                    os.makedirs(proc_output_dir, exist_ok=True)

                    # The extra args are prepended so they do not interfere with a
                    # potential argparse.REMAINDER style arg at the end. device_id is
                    # 0 because, with CUDA_VISIBLE_DEVICES set, the process can only
                    # see its own GPU.
                    proc_args = [
                        "--output_dir",
                        proc_output_dir,
                        "--device_id",
                        "0",
                        *args.args,
                    ]
                    result = pool.apply_async(
                        benchmark_script,
                        args=(
                            process_idx,
                            proc_device_id,
                            proc_output_dir,
                            args.warmup_batches,
                            args.script,
                            proc_args,
                        ),
                    )
                    logger.info(
                        "Launched process: %d. gpu-idx: %d", process_idx, gpu_idx
                    )
                    results.append((process_idx, gpu_idx, proc_device_id, result))

            # Wait for everything to finish.
            pool.close()
            pool.join()
    finally:
        # Always stop the monitor and restore the clocks, even on failure. Otherwise
        # the non-daemon monitor process would keep looping (and block interpreter
        # exit), and the GPUs would stay at maximized clocks.
        pool_terminate_event.set()
        if monitor_started:
            gpu_metric_monitor_proc.join()

        if args.maximize_clocks:
            for gpu_idx, (was_persistence_mode_on, current_power_limit) in enumerate(
                clocks_info
            ):
                reset_clocks(
                    logger,
                    args.gpu_offset_id + gpu_idx,
                    was_persistence_mode_on,
                    current_power_limit,
                )
        else:
            logger.warning("Clocks were not maximized during this run.")

    # Detach gpu_metrics_info from multiprocessing before shutting the manager down.
    gpu_metrics_info = gpu_metrics_info.copy()
    mp_manager.shutdown()

    # Plot the GPU metrics.
    plot_gpu_metrics(gpu_metrics_info, args.output_dir)

    # Now we need to:
    # 1) Write the gpu_metrics_info into the benchmark.json file of every process, and
    # 2) Calculate various stats at the all-processes level. A benchmark_mean.json is
    #    created in the output root folder, with means and other stats computed from
    #    the benchmark.json files of all processes.
    #    This is only possible if every process finished without error, so that is
    #    checked first.
    all_data_dicts = []
    benchmark_dict = None
    for process_idx, gpu_idx, proc_device_id, result in results:
        proc_ret_code, proc_output_dir = result.get()
        if proc_ret_code:
            # Any non-zero return code means the process failed.
            raise Exception(
                "Process: %d on gpu: %d exited with a non-zero return code: %d"
                % (process_idx, gpu_idx, proc_ret_code)
            )

        benchmark_json_path = os.path.join(proc_output_dir, "benchmark.json")
        with open(benchmark_json_path, "r") as f:
            benchmark_dict = json.load(f)

        # Attach the metrics of the GPU this process ran on. The physical id used at
        # launch is the same string the monitor used as its key. The script itself only
        # sees one GPU and is told --device_id 0, so the id in its own meta may not
        # match the monitored one.
        for metric_name in gpu_metrics_info:
            benchmark_dict["gpu_metrics"][metric_name] = gpu_metrics_info[
                metric_name
            ]["GPU: %s" % proc_device_id]

        with open(benchmark_json_path, "w") as f:
            f.write(json.dumps(benchmark_dict, indent=4, cls=NumpyValuesEncoder))

        all_data_dicts.append(benchmark_dict["data"])

    # 1) Compute the mean of the data field over all processes: first gather all
    #    values recursively, then compute only the mean.
    data_mean_all_procs = {}
    for data_dict in all_data_dicts:
        recurse_gather_dict(data_dict, data_mean_all_procs)
    recurse_calc_stats_dict(data_mean_all_procs, compute_mean_only=True)

    # 2) Compute the stats (mean, median, ...) of the non-warm-up batches over all
    #    processes. Each benchmark.json already has per-process stats in
    #    data_stats_minus_warmup. Here they are recomputed from all data points
    #    combined, which is more accurate than averaging those means.
    #    The last process's benchmark_dict supplies batch_info and total_batches. This
    #    assumes all processes ran the same code.
    data_stats_all_procs = {}
    for batch_range_name in benchmark_dict["batch_info"]:
        batch_idx, batch_size = benchmark_dict["batch_info"][batch_range_name]
        # The batch level prefix is the key the batches are nested in. A profiling
        # session can use batches at several levels, e.g.
        # program_X:
        #   method_A:
        #       batch_1
        #       batch_2
        #   method_B:
        #       batch_1
        #       batch_2
        # Here the means are needed at program_X/method_A and program_X/method_B.
        # Since keys look like paths, dirname gives the prefix directly.
        batch_level_prefix = os.path.dirname(batch_range_name)

        if (
            batch_size > 0
            and batch_idx + 1 > args.warmup_batches
            and (batch_idx + 1 + args.warmup_batches)
            <= benchmark_dict["meta"]["total_batches"][batch_level_prefix]
        ):
            # Build the union of all batch-level dictionaries of all processes.
            nested_keys = batch_range_name.split("/")

            for data_dict in all_data_dicts:
                source_dict = data_dict
                # Walk down the nested key path from the root.
                for k in nested_keys:
                    source_dict = source_dict[k]

                if batch_level_prefix not in data_stats_all_procs:
                    data_stats_all_procs[batch_level_prefix] = {}

                recurse_gather_dict(
                    source_dict, target_dict=data_stats_all_procs[batch_level_prefix]
                )

    # Once all data points are gathered, compute the stats and the throughput of all
    # concurrently running processes.
    recurse_calc_stats_dict(
        data_stats_all_procs,
        compute_throughput=True,
        throughput_multiplier=total_processes,
    )

    # 3) Compute the stats of all GPU metrics over all GPUs involved.
    gpu_metrics_all_procs = {}
    for metric_name in gpu_metrics_info:
        gpu_metrics_all_procs[metric_name] = []
        for device_id in gpu_metrics_info[metric_name]:
            gpu_metrics_all_procs[metric_name].extend(
                gpu_metrics_info[metric_name][device_id]
            )

    recurse_calc_stats_dict(
        gpu_metrics_all_procs,
    )

    mean_benchmark_data = {
        "data_mean_all_procs": data_mean_all_procs,
        "data_stats_minus_warmup_all_procs": data_stats_all_procs,
        "gpu_metrics_all_procs": gpu_metrics_all_procs,
        "meta": {"args": dict(vars(args))},
    }

    mean_benchmark_json_path = os.path.join(args.output_dir, "benchmark_mean.json")
    with open(mean_benchmark_json_path, "w") as f:
        f.write(json.dumps(mean_benchmark_data, indent=4, cls=NumpyValuesEncoder))
    logger.info(
        "Benchmarking completed successfully. Results saved at: %s",
        mean_benchmark_json_path,
    )


if __name__ == "__main__":
    main()
